import jax
import jax.numpy as jnp
import json
import argparse
import os, sys

# Make the sibling project directories (two levels above this file) importable.
# Each directory is pushed to the front of sys.path in turn, so the one listed
# last ends up being searched first; `root` is left pointing at that last one.
for _project_dir in ('sae-jax', 'lsh', 'sae-softmax'):
    root = os.path.abspath(
        os.path.join(os.path.dirname(__file__), '..', '..', _project_dir)
    )
    sys.path.insert(0, root)

from lsh import GemmaLSH

from sae_save_load import (
    save_model, 
    load_model, 
    save_checkpoint, 
    load_checkpoint,
    save_metadata,
    load_jax_sae_to_pytorch,
    encode_sparse_torch
)


from functools import partial
import numpy as np
from typing import Union, List

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from dual import GemmaEmbeddingPredictor

from eval_misc import process_batch, get_sparse_representations_and_reconstructions, find_top_k_embeddings_cosine_similarity, SparseCodeMatcher, compute_kl_divergence, compute_top_k_overlap, quarter_sentence
from datasets import load_dataset


def parse_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding the model/code/MLP paths, the generation
        settings (prompt, temperature, top_k, top_p, max_new_tokens, method),
        the dataset settings (dataset_name, total_samples, save_every) and the
        cache/output directories. ``output_dir`` is ``None`` unless given.

    Note:
        Path defaults contain a literal ``~``; this function does not expand it.
    """
    parser = argparse.ArgumentParser(description='Evaluate SAE-aware token prediction')
    parser.add_argument('--sae_model_path', type=str,
                      default='~/gemma-7b-sae/k5_whole_sae_final_model.pkl',
                      help='Path to the SAE model')
    parser.add_argument('--sae_code_path', type=str,
                      default='~/gemma-7b-sae/k5_whole_sae_final_z.npy',
                      help='Path to the SAE code')
    parser.add_argument('--mlp_model_path', type=str,
                      default="~/dual-map/dual_map_mlp_model_small_rescaled_experimental.pt",
                      help='Path to the MLP model')
    parser.add_argument('--prompt', type=str, default=None,
                      help='Custom prompt to use for generation. If provided, will not use dataset.')
    # Help texts below interpolate the real default via %(default)s so the
    # documented value can never drift from the actual one.
    parser.add_argument('--temperature', type=float, default=1.0,
                      help='Temperature for generation (default: %(default)s)')
    parser.add_argument('--top_k', type=int, default=50,
                      help='Top-k sampling parameter (default: %(default)s)')
    parser.add_argument('--top_p', type=float, default=0.9,
                      help='Top-p (nucleus) sampling parameter (default: %(default)s)')
    parser.add_argument('--max_new_tokens', type=int, default=300,
                      help='Maximum number of new tokens to generate (default: %(default)s)')
    parser.add_argument('--method', type=str, default="both", choices=["greedy", "top_k", "top_p", "both"],
                      help='Generation method: greedy, top_k, top_p, or both (default: %(default)s)')
    parser.add_argument('--model_name', type=str, default="google/gemma-7b",
                      help='Name of the base model to use')
    parser.add_argument('--dataset_name', type=str, default="bookcorpus",
                      help='Name of the dataset to use')
    parser.add_argument('--total_samples', type=int, default=1000,
                      help='Total number of samples to evaluate')
    parser.add_argument('--save_every', type=int, default=50,
                      help='Save results every N samples')
    parser.add_argument('--cache_dir', type=str, default="~/gemma_cache",
                      help='Directory to cache model files')
    parser.add_argument('--output_dir', type=str,
                      default=None,
                      help='Base directory to save output files. If not specified, will use ~/results/sae-softmax/{model_name}')

    return parser.parse_args()

# Parse command line arguments
args = parse_args()

# Set default output directory if not specified
if args.output_dir is None:
    model_name_short = args.model_name.split('/')[-1]  # Get just the model name without org
    args.output_dir = f"~/results/sae-softmax/{model_name_short}"

# The path defaults start with "~". Neither os.makedirs nor np.load expands it
# (makedirs would create a literal "./~" directory), so resolve every
# filesystem argument once here. expanduser is a no-op on expanded paths.
for _path_arg in ("sae_model_path", "sae_code_path", "mlp_model_path", "cache_dir", "output_dir"):
    _path_value = getattr(args, _path_arg)
    if _path_value is not None:
        setattr(args, _path_arg, os.path.expanduser(_path_value))

# Create output directory if it doesn't exist
os.makedirs(args.output_dir, exist_ok=True)

sae_model = load_jax_sae_to_pytorch(args.sae_model_path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gemma_tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=args.cache_dir)
gemma_model = AutoModelForCausalLM.from_pretrained(args.model_name, cache_dir=args.cache_dir).to(device)

# Invert the tokenizer vocabulary into an id -> token-string list; ids that the
# tokenizer does not define keep the "<unused>" placeholder.
vocab_dict = gemma_tokenizer.get_vocab()
vocab_list = ["<unused>"] * (max(vocab_dict.values()) + 1)
for word, index in vocab_dict.items():
    vocab_list[index] = word

# SAE codes loaded from disk (presumably one row per vocabulary entry — confirm
# against how the file was produced) and the model's output embedding matrix.
z = np.load(args.sae_code_path)
output_embeddings = gemma_model.get_output_embeddings().weight.to(device)
input_dim = output_embeddings.shape[1]

class NextToken:
    """Compute next-token logits, optionally restricted to a candidate set.

    The logits are always ``last_token_embedding @ output_embeddings.T``. For
    every method except ``"default"`` a query embedding is encoded with the
    SAE, a set of candidate token indices is retrieved from the stored codes
    ``z``, and every logit outside that set is replaced by ``-inf``.

    Args:
        z: Stored SAE codes; passed to ``SparseCodeMatcher`` and to
            ``find_top_k_embeddings_cosine_similarity``.
        output_embeddings: 2-D tensor whose second dimension is the embedding
            size (used as ``(vocab, dim)`` in the logit matmul).
        mlp_model_path: Forwarded to ``GemmaEmbeddingPredictor``.
        sae_model: Forwarded to ``encode_sparse_torch``.
        top_n_canidates: ``k`` for the cosine-similarity retrieval used by the
            ``"full"`` and ``"approxi"`` methods (parameter name kept as-is
            for backward compatibility).
    """

    # Every value accepted by get_next_logits(method=...).
    _METHODS = ("default", "full", "approxi", "full+sparselookup", "approxi+sparselookup")

    def __init__(self, z, output_embeddings, mlp_model_path, sae_model, top_n_canidates=4000):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Keep our own reference so retrieval does not depend on a module global.
        self.z = z
        self.sparsecode = SparseCodeMatcher(z)
        self.original_g = output_embeddings.to(self.device)
        self.input_dim = self.original_g.shape[1]

        # Centre the embedding matrix and build vt.T @ diag(1/sqrt(s + eps)) @ vt
        # from its SVD.
        # NOTE(review): this scales by 1/sqrt(singular value), not 1/singular
        # value; presumably it mirrors the preprocessing the SAE was trained
        # with — confirm before changing.
        self.mean = self.original_g.mean(axis=0)
        original_g_centered = self.original_g - self.mean
        _, s, vt = torch.linalg.svd(original_g_centered, full_matrices=False)

        self.whitening_matrix = torch.matmul(
                torch.matmul(vt.T, torch.diag(1.0 / torch.sqrt(s + 1e-6))),
                vt
            )

        self.embedpredictor = GemmaEmbeddingPredictor(input_dim=self.input_dim, mlp_model_path=mlp_model_path)
        self.sae_model = sae_model
        self.top_n_canidates = top_n_canidates

    def get_next_logits(self,
                        last_token_embedding,
                        method = "default", top_candidates=200000):
        """Return ``(logits, n_candidates)`` for a single 1-D hidden state.

        Args:
            last_token_embedding: 1-D tensor of size ``input_dim``.
            method: One of
                ``"default"`` – plain logits, no masking (``n_candidates`` is 0);
                ``"full"`` / ``"full+sparselookup"`` – query is the
                softmax-weighted mean output embedding, centred and whitened;
                ``"approxi"`` / ``"approxi+sparselookup"`` – query comes from
                the embedding predictor.
                Without the ``+sparselookup`` suffix candidates are the
                ``top_n_canidates`` codes by cosine similarity; with it they
                come from ``SparseCodeMatcher.retrieve_similar_codes``.
            top_candidates: ``max_codes`` passed to the sparse lookup.

        Returns:
            Tuple of the logits tensor (non-candidates set to ``-inf``) and
            the number of retrieved candidate indices.

        Raises:
            ValueError: if ``method`` is not one of the supported names.
        """
        if method not in self._METHODS:
            raise ValueError(
                f"Unknown method {method!r}; expected one of {self._METHODS}"
            )

        with torch.no_grad():
            last_token_embedding = last_token_embedding.to(self.device)
            next_token_logits = last_token_embedding @ self.original_g.T

            if method == "default":
                return next_token_logits, 0

            if method.startswith("full"):
                next_token_probs = torch.softmax(next_token_logits, dim=0)
                expected_unembedding = torch.matmul(next_token_probs, self.original_g)
                test_embedding = (expected_unembedding - self.mean) @ self.whitening_matrix
                test_embedding = test_embedding.detach()
            else:
                test_embedding = self.embedpredictor.predict_next_token_embedding(last_token_embedding)

            test_z = encode_sparse_torch(self.sae_model, test_embedding)

            if method.endswith("+sparselookup"):
                top_indices = self.sparsecode.retrieve_similar_codes(test_z, max_codes=top_candidates)
            else:
                top_indices = find_top_k_embeddings_cosine_similarity(self.z, test_z, k=self.top_n_canidates)

            # In the mask, True means "discard": start with everything
            # discarded, then re-enable the retrieved candidates.
            mask = torch.ones_like(next_token_logits, dtype=torch.bool)
            mask[top_indices] = False

            next_token_logits = next_token_logits.masked_fill(mask, float("-inf"))

            return next_token_logits, len(top_indices)
                                

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Mask logits outside the top-k and/or nucleus (top-p) set.

    Args:
        logits: Tensor of shape [vocab] or [batch, vocab]. Not modified.
        top_k: Keep only the ``top_k`` largest logits per row; ``0`` disables
            the filter. Values larger than the vocabulary size keep everything.
        top_p: Keep the smallest prefix of tokens (sorted by probability)
            whose cumulative probability exceeds ``top_p``; the most likely
            token is always kept. ``0.0`` or values ``>= 1.0`` disable it.
        filter_value: Value written into removed positions.

    Returns:
        A tensor with the same shape as ``logits``.
    """
    # make everything 2D for uniformity
    was_1d = (logits.dim() == 1)
    if was_1d:
        logits = logits.unsqueeze(0)  # [1, vocab]

    # Top-K filtering
    if top_k > 0:
        # torch.topk raises when k exceeds the dimension size, so clamp first.
        top_k = min(top_k, logits.size(-1))
        kth_vals = torch.topk(logits, top_k, dim=-1).values[..., -1:]  # [batch, 1]
        logits = logits.masked_fill(logits < kth_vals, filter_value)

    # Top-P (nucleus) filtering. top_p >= 1 keeps everything; skipping it
    # avoids dropping tail tokens when the float cumsum lands just above 1.0.
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Mark tokens past the threshold, then shift right by one so the token
        # that crosses the threshold (and always the first token) is kept.
        sorted_mask = cumulative_probs > top_p
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = False

        # scatter sorted_mask back to the original logits order
        mask = torch.zeros_like(logits, dtype=torch.bool)             # [batch, vocab]
        mask.scatter_(dim=-1, index=sorted_indices, src=sorted_mask)
        logits = logits.masked_fill(mask, filter_value)

    # squeeze back if needed
    if was_1d:
        return logits.squeeze(0)  # [vocab]
    return logits              # [batch, vocab]


def generate_manual(
    model, tokenizer,
    prompt: Union[str, List[str]],
    nextokenpred = None,
    nextokenpred_methpd = "default",
    max_new_tokens: int = 5,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 0.0,
    method: str = "greedy",  # "greedy", "top_k", "top_p", or "both"
):
    """Autoregressively generate text, one full forward pass per new token.

    Args:
        model: Causal LM called as ``model(input_ids=..., attention_mask=...)``
            whose output exposes ``.logits`` (and ``.hidden_states`` when a
            predictor is supplied).
        tokenizer: Tokenizer used to encode ``prompt`` and decode the result;
            its ``eos_token_id`` marks a finished sequence.
        prompt: One prompt or a list of prompts (padded into a batch).
        nextokenpred: Optional object with ``get_next_logits(embedding,
            method=...)``. When given, it replaces the model's own
            last-position logits, one sequence at a time, using the final
            hidden state at the last position.
        nextokenpred_methpd: ``method`` forwarded to ``get_next_logits``.
        max_new_tokens: Upper bound on generated tokens per sequence.
        temperature: Logits are divided by this value; must be positive.
        top_k, top_p: Filtering parameters, applied only when ``method``
            selects them.
        method: ``"greedy"`` takes the argmax; ``"top_k"``, ``"top_p"`` and
            ``"both"`` sample after the corresponding filtering.

    Returns:
        List with one decoded string per prompt (prompt included, special
        tokens skipped).

    Raises:
        ValueError: if ``temperature`` is not strictly positive.
    """
    # Dividing by a non-positive temperature yields inf/NaN or an inverted
    # distribution, so fail fast instead of sampling garbage.
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    device = next(model.parameters()).device
    model.eval()

    # Tokenize and move to device
    tokenized = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,             # pad a batch of prompts
        truncation=True
    ).to(device)

    generated = tokenized.input_ids
    attention_mask = tokenized.attention_mask.to(device)

    batch_size = generated.size(0)
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
    use_predictor = nextokenpred is not None

    for _ in range(max_new_tokens):
        # NOTE(review): reading position -1 assumes batched prompts are
        # left-padded — confirm tokenizer.padding_side for multi-prompt calls.
        with torch.no_grad():
            outputs = model(
                input_ids=generated,
                attention_mask=attention_mask,
                output_hidden_states=use_predictor,  # only needed by the predictor
            )
            model_logits = outputs.logits[:, -1, :]
            if use_predictor:
                last_hidden = outputs.hidden_states[-1][:, -1, :]
                per_sequence = [
                    nextokenpred.get_next_logits(last_hidden[i], method=nextokenpred_methpd)[0]
                    for i in range(batch_size)
                ]
                # Build a fresh tensor rather than overwriting the model's
                # output in place; match the model logits' device and dtype.
                next_token_logits = (torch.stack(per_sequence, dim=0) / temperature).to(
                    device=model_logits.device, dtype=model_logits.dtype
                )
            else:
                next_token_logits = model_logits / temperature

        if method == "greedy":
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        else:
            # Apply top-k and/or top-p filtering
            filtered_logits = top_k_top_p_filtering(
                next_token_logits,
                top_k=top_k if method in ("top_k", "both") else 0,
                top_p=top_p if method in ("top_p", "both") else 0.0
            )
            # Sample
            probs = F.softmax(filtered_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

        # Sequences that already emitted EOS keep emitting EOS.
        next_token = torch.where(
            finished.unsqueeze(1),
            torch.full_like(next_token, tokenizer.eos_token_id),
            next_token
        )

        # update finished flags
        finished |= next_token.squeeze(1) == tokenizer.eos_token_id

        # Append and continue
        generated = torch.cat((generated, next_token), dim=1)
        attention_mask = torch.cat(
            (attention_mask, torch.ones_like(next_token)),
            dim=1
        )

        if finished.all():
            break

    return [tokenizer.decode(g, skip_special_tokens=True) for g in generated]

nextokenpred = NextToken(z, output_embeddings, args.mlp_model_path, sae_model)


def _generate(prompt, logits_method):
    """Run generate_manual with the CLI settings; return the first decoded text.

    ``logits_method`` is forwarded as ``nextokenpred_methpd`` ("default" for
    unrestricted logits, "approxi+sparselookup" for the SAE-restricted ones).
    """
    return generate_manual(
        gemma_model,
        gemma_tokenizer,
        prompt,
        nextokenpred=nextokenpred,
        nextokenpred_methpd=logits_method,
        method=args.method,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        max_new_tokens=args.max_new_tokens
    )[0]


if args.prompt is not None:
    # Use the provided prompt
    prompt = args.prompt
    print("Using custom prompt:", prompt)
    print("\nDefault:   \n\n", _generate(prompt, "default"))
    print("\nNew:   \n\n", _generate(prompt, "approxi+sparselookup"))
else:
    # Use the dataset as before
    if args.dataset_name.endswith(".jsonl"):
        ds = load_dataset("json", data_files=args.dataset_name)
        ds = ds["train"]
    else:
        ds = load_dataset(args.dataset_name, split="train")
    total_samples = max(args.total_samples, 0)
    save_every = args.save_every  # Flush results to disk every N samples

    from tqdm import tqdm

    # Resolve "~" and make sure the directory exists before opening the file.
    output_dir = os.path.expanduser(args.output_dir)
    os.makedirs(output_dir, exist_ok=True)
    output_filename = os.path.join(output_dir, f"samples-{args.method}-temp-{args.temperature}-max-new-tokens-{args.max_new_tokens}.jsonl")

    cnt = 0
    results = []
    # Records are streamed to the file as they are produced, so an interrupted
    # run keeps everything up to the last flush instead of losing all results.
    with open(output_filename, "w") as f:
        # Iterate rows lazily rather than materialising the whole "text" column.
        for row in tqdm(ds, total=min(total_samples, len(ds))):
            if cnt >= total_samples:
                break
            sentence = row["text"]
            # quarter_sentence presumably keeps the leading part of the
            # sentence as the prompt — see eval_misc to confirm.
            prompt = quarter_sentence(sentence)
            default_output = _generate(prompt, "default")
            new_output = _generate(prompt, "approxi+sparselookup")

            result = {
                "prompt": prompt,
                "default": default_output,
                "new": new_output
            }
            results.append(result)
            f.write(json.dumps(result) + "\n")
            cnt += 1
            if save_every > 0 and cnt % save_every == 0:
                f.flush()

            print(f"\n\nPrompt:\n{prompt}")
            print(f"\n\nDefault:\n{default_output}", )
            print(f"\n\nNew:\n{new_output}")